import io
import os, sys
import PIL

import torch.nn.functional as F

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import torchvision.utils
from scripts.inference import get_model as get_dvae

target_image_size = 224


def preprocess(img):
    """Resize, center-crop and tensorize an image for the tokenizer.

    The image is rescaled so that its shorter side equals
    ``target_image_size`` (Lanczos resampling), center-cropped to a
    ``target_image_size`` x ``target_image_size`` square, converted with
    ``ToTensor`` and multiplied by 255.

    Args:
        img: A PIL image; ``img.size`` is treated as ``(width, height)``.

    Returns:
        The converted image tensor with an extra leading dimension of
        size 1, scaled by 255.

    Raises:
        ValueError: If the shorter side of ``img`` is smaller than
            ``target_image_size`` (the image is never upsampled).
    """
    width, height = img.size
    s = min(width, height)
    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')

    scale = target_image_size / s
    # TF.resize takes (height, width), the reverse of PIL's size ordering.
    new_hw = (round(scale * height), round(scale * width))
    resized = TF.resize(img, new_hw, interpolation=PIL.Image.LANCZOS)
    cropped = TF.center_crop(
        resized, output_size=[target_image_size, target_image_size])
    return T.ToTensor()(cropped).unsqueeze(0) * 255

# Checkpoint of the trained tokenizer that is loaded through get_dvae().
TOKENIZER_PATH='/home/v-yuxinfang/container/beit/tokenizer/PeCo/OUTPUT/dVAE/vqvae_img224_ds16_ch64_lr2e-5_pixel1_vitb_percep0.08_emaVQ_code8192_256/checkpoint/000067e_1387915iter.pth'
# Load the model on the GPU and switch it to eval mode for inference.
d_vae = get_dvae(TOKENIZER_PATH).cuda().eval()

# Report the number of trainable parameters, in millions.
n_parameters = sum(p.numel() for p in d_vae.parameters() if p.requires_grad)
print('number of params (M):', n_parameters / 1000000)

# Directory holding the original images to reconstruct.
IM_PATH = '/home/v-yuxinfang/container/beit/tokenizer/PeCo/OUTPUT/TEST/Ori'
# os.listdir() returns entries in arbitrary order; sort so that the rows of
# the saved grid are reproducible between runs.
imgs = sorted(os.listdir(IM_PATH))
out = []

for img in imgs:
    print(img)
    # Close the image file as soon as it has been converted to a tensor.
    with PIL.Image.open(os.path.join(IM_PATH, img)) as pil_img:
        x = preprocess(pil_img).cuda()
    with torch.no_grad(), torch.cuda.amp.autocast():
        z = d_vae.get_tokens(x)['token']
        print(z.shape)
        x_rec = d_vae.decode(z)
        print(x_rec.shape)
    # The original script called exit() here (debug leftover), which ended
    # the process after the first image and made everything below dead code.
    #
    # Under autocast the decoder output may be half precision, so cast it
    # to float32 to match ``x`` before joining the two along dim 2.
    # NOTE(review): assumes x_rec has the same shape as x (apart from dim 2)
    # — confirm against d_vae.decode.
    x = torch.cat([x, x_rec.float()], dim=2)

    out.append(x)

if not out:
    # torch.cat() on an empty list fails with a far less helpful message.
    raise RuntimeError(f'no images found in {IM_PATH}')

# NOTE(review): dividing by 255 assumes d_vae.decode returns values on the
# same 0-255 scale as the preprocessed input — confirm.
out = torch.cat(out, dim=0) / 255.

out_dir = './OUTPUT/TEST/Out'
# save_image() does not create missing directories.
os.makedirs(out_dir, exist_ok=True)
torchvision.utils.save_image(
    out,
    os.path.join(out_dir, 'recon.jpg'),
    padding=0,
    normalize=False)


    
